﻿#pragma kernel UpdateLineMesh

#include "PathFrame.cginc"

// Per-chunk path data, read through the `pathData` buffer.
// Only `restLength` and `smoothLength` are used in this file; the remaining fields are
// presumably consumed by other kernels -- confirm before changing them.
// NOTE(review): field order/types presumably mirror a CPU-side struct; keep both in sync.
struct smootherPathData
{
    uint smoothing;             // not read in this file
    float decimation;           // not read in this file
    float twist;                // not read in this file
    float restLength;           // read from the rope's first chunk to anchor and scale the V texcoord
    float smoothLength;         // summed over all chunks of a rope to obtain its total smoothed length
    uint usesOrientedParticles; // not read in this file
};

// Per-renderer mesh settings, read through the `rendererData` buffer.
// NOTE(review): field order/types presumably mirror a CPU-side struct; keep both in sync.
struct lineMeshData
{
    float2 uvScale;       // only .y is read in this file: it scales the V texcoord
    float thicknessScale; // multiplied with each frame's thickness to get the section half-width
    float uvAnchor;       // multiplied with uvScale.y * restLength to get the (negated) starting V
    uint normalizeV;      // == 1: V advances by distance / smoothLength; otherwise by distance / (smoothLength / restLength)
};

// Indexed by renderer (firstRenderer + thread id): yields the index `s` used to look up chunkOffsets.
StructuredBuffer<int> pathSmootherIndices;
// chunkOffsets[s] .. chunkOffsets[s + 1] is the half-open range of chunk indices belonging to one rope.
StructuredBuffer<int> chunkOffsets;

// Path frames of all chunks; chunk i occupies frameCounts[i] entries starting at frameOffsets[i].
StructuredBuffer<pathFrame> frames;
StructuredBuffer<int> frameOffsets;
StructuredBuffer<int> frameCounts;

// Indexed by renderer: first vertex slot, first triangle slot and number of triangle slots
// assigned to that renderer inside `vertices` / `tris`.
StructuredBuffer<int> vertexOffsets;
StructuredBuffer<int> triangleOffsets;
StructuredBuffer<int> triangleCounts;

// rendererData is indexed by renderer, pathData is indexed by chunk.
StructuredBuffer<lineMeshData> rendererData;
StructuredBuffer<smootherPathData> pathData;

// Output vertex data: the kernel writes 16 dwords (64 bytes) per vertex.
RWByteAddressBuffer vertices;
// Output index data: the kernel writes 3 dwords per triangle.
RWByteAddressBuffer tris;

// Variables set from the CPU
uint firstRenderer;      // added to the thread id to obtain the renderer index
uint rendererCount;      // threads with id.x >= rendererCount exit immediately
float3 localSpaceCamera; // subtracted from frame positions to get the view direction
                         // NOTE(review): presumably expressed in the same space as frames[].position -- confirm

// Re-aims `frame` so that its tangent points from frame.position towards target.position.
//   frame  : frame to re-orient (taken by value; the adjusted copy is returned).
//   target : frame whose position is looked at.
//   dist   : receives the distance between both positions.
// The frame's normal and binormal are transformed by the quaternion that from_to_rotation
// returns for (old tangent, new tangent).
pathFrame LookAt(pathFrame frame, in pathFrame target, out float dist)
{
    float3 toTarget = target.position - frame.position;
    dist = length(toTarget);

    // EPSILON comes from an included file; presumably a small positive constant that keeps
    // this division finite when both positions coincide -- confirm.
    float3 aimedTangent = toTarget / (dist + EPSILON);

    quaternion realign = from_to_rotation(frame.tangent, aimedTangent);
    frame.binormal = rotate_vector(realign, frame.binormal);
    frame.normal = rotate_vector(realign, frame.normal);
    frame.tangent = aimedTangent;

    return frame;
}

// Returns normalize(v), or the zero vector when v has no length.
// normalize() divides by the vector length, so a zero-length input would otherwise yield NaN.
float3 NormalizeOrZero(float3 v)
{
    float3 result = float3(0, 0, 0);
    if (dot(v, v) > 0)
        result = normalize(v);
    return result;
}

// Writes one vertex into `vertices` at vertex slot `vertexIndex`.
// Each vertex occupies 16 dwords:
//   [0..2] position, [3..5] normal, [6..9] tangent (xyz + w), [10..13] color, [14..15] uv.
// NOTE(review): this presumably has to match the vertex layout declared for the mesh on the CPU -- confirm.
void StoreLineVertex(int vertexIndex, float3 position, float3 normal, float4 tangent, float4 color, float2 uv)
{
    int base = vertexIndex * 16; // in dwords; << 2 below converts to a byte address
    vertices.Store3( base       << 2, asuint(position));
    vertices.Store3((base + 3)  << 2, asuint(normal));
    vertices.Store4((base + 6)  << 2, asuint(tangent));
    vertices.Store4((base + 10) << 2, asuint(color));
    vertices.Store2((base + 14) << 2, asuint(uv));
}

// Builds a camera-facing ribbon mesh for each renderer; one thread handles one renderer.
//
// For every path frame of every chunk of the renderer's rope, two vertices are emitted,
// displaced from the frame position by -/+ bitangent * thickness, where the bitangent is
// perpendicular to both the frame tangent and the direction from the camera to the frame.
// Consecutive frames of the same chunk are joined by two triangles; the last frame of a
// chunk is not joined to the first frame of the next one.
//
// Degenerate inputs (frame located at the camera position, view direction parallel to the
// tangent, zero-length rope) produce zero vectors / a non-advancing V coordinate instead of NaN.
[numthreads(128, 1, 1)]
void UpdateLineMesh (uint3 id : SV_DispatchThreadID) 
{
    if (id.x >= rendererCount) return;

    int k = firstRenderer + id.x;
    int s = pathSmootherIndices[k];

    // Per-renderer values are constant for the whole thread: read them once.
    lineMeshData meshData = rendererData[k];
    int firstVertex = vertexOffsets[k];
    int firstTriangle = triangleOffsets[k];
    int firstChunk = chunkOffsets[s];
    int endChunk = chunkOffsets[s + 1];

    // total smoothed length of the rope (sum over all of its chunks):
    float smoothLength = 0;
    for (int c = firstChunk; c < endChunk; ++c)
        smoothLength += pathData[c].smoothLength;

    float restLength = pathData[firstChunk].restLength;
    float vCoord = -meshData.uvScale.y * restLength * meshData.uvAnchor;
    float actualToRestLengthRatio = smoothLength / restLength;

    // Divisor applied to every segment length when advancing V. It is 0 or NaN for a
    // zero-length rope, in which case V is left untouched instead of becoming NaN/Inf.
    float vDivisor = (meshData.normalizeV == 1) ? smoothLength : actualToRestLengthRatio;
    bool vDivisorUsable = abs(vDivisor) > 0;

    // clear out triangle indices for this rope (3 indices per triangle slot):
    int firstTriIndex = firstTriangle * 3;
    int triIndexCount = triangleCounts[k] * 3;
    for (int t = 0; t < triIndexCount; ++t)
        tris.Store((firstTriIndex + t) << 2, 0);

    int writtenIndices = 0; // triangle indices emitted so far for this renderer
    int sectionIndex = 0;   // cross-sections emitted so far (2 vertices each)

    // for each chunk in the rope:
    for (int chunk = firstChunk; chunk < endChunk; ++chunk)
    {
        int firstFrame = frameOffsets[chunk];
        int frameCount = frameCounts[chunk];

        for (int f = 0; f < frameCount; ++f)
        {
            // previous frame index, clamped so the first frame of a chunk uses itself:
            int prevIndex = firstFrame + max(f - 1, 0);
            pathFrame frame = frames[firstFrame + f];

            // advance v texcoord:
            if (vDivisorUsable)
                vCoord += meshData.uvScale.y * (distance(frame.position, frames[prevIndex].position) / vDivisor);

            // calculate section thickness:
            float sectionThickness = frame.thickness * meshData.thicknessScale;

            // unit direction from the camera towards the frame; its negation is stored as vertex normal:
            float3 viewDir = NormalizeOrZero(float3(frame.position.x - localSpaceCamera.x,
                                                    frame.position.y - localSpaceCamera.y,
                                                    frame.position.z - localSpaceCamera.z));

            // ribbon side direction, perpendicular to both the view direction and the frame tangent:
            float3 frameTangent = float3(frame.tangent.x, frame.tangent.y, frame.tangent.z);
            float4 bitangent = float4(NormalizeOrZero(-cross(viewDir, frameTangent)), 1);

            float3 center = float3(frame.position.x, frame.position.y, frame.position.z);
            float3 sideOffset = bitangent.xyz * sectionThickness;

            // two vertices per section: u = 0 on the -bitangent side, u = 1 on the +bitangent side.
            int v0 = firstVertex + sectionIndex * 2;
            StoreLineVertex(v0,     center - sideOffset, -viewDir, bitangent, frame.color, float2(0, vCoord));
            StoreLineVertex(v0 + 1, center + sideOffset, -viewDir, bitangent, frame.color, float2(1, vCoord));

            // join this section with the next one, except for the last frame of the chunk:
            if (f < frameCount - 1)
            {
                int next0 = v0 + 2; // first vertex of the next section

                tris.Store((firstTriIndex + writtenIndices++) << 2, asuint(v0));
                tris.Store((firstTriIndex + writtenIndices++) << 2, asuint(next0));
                tris.Store((firstTriIndex + writtenIndices++) << 2, asuint(v0 + 1));

                tris.Store((firstTriIndex + writtenIndices++) << 2, asuint(v0 + 1));
                tris.Store((firstTriIndex + writtenIndices++) << 2, asuint(next0));
                tris.Store((firstTriIndex + writtenIndices++) << 2, asuint(next0 + 1));
            }

            sectionIndex++;
        }
    }
}